package cn.doitedu.recomment.cb

import java.io.File

import cn.doitedu.commons.utils.{FileUtils, SparkUtil}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer, VectorSlicer}
import org.apache.spark.ml.linalg
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.DataTypes

import scala.collection.mutable

/**
  * @author: 余辉
  * @blog: https://blog.csdn.net/silentwolfyh
  * @create: 2019/10/21
  * @description:
  * 需求：物品相似度计算
  *
  * 步鄹：
  * 1、pid,price,cat1,cat2,cat3,kwds =》 pid，【向量】
  * 2、通过余弦相似度求出两个物品的相似度
  *
  * 详细步鄹：
  * 1、读取csv文件，提取头文件
  * 2、将【kwds】进行切分,通过 HashingTF 转为词频，之后通过IDF进行转换为逆文档频率
  * 3、将【cat1,cat2,cat3】三列的内容，通过 StringIndexer 转为索引值
  * 4、将【price,cat1,cat2,cat3,kwds】整合成一个特征向量==》 （pid ， features）
  * 5、通过余弦相似度求出两个物品的相似度
  **/
object ItemSimilarity {
  def main(args: Array[String]): Unit = {

    // 1、建立session连接
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark: SparkSession = SparkUtil.getSparkSession(this.getClass.getSimpleName)
    import spark.implicits._
    import org.apache.spark.sql.functions._

    // 2、提取原始，注册成临时表
    val item: DataFrame = spark.read.option("header", true).option("inferSchema", true).csv("rec_system/data/ui_rate/item.profile2.dat")
    val items = item.rdd.map({
      case Row(pid: String, level: Double, cat1: String, cat2: String, cat3: String, kwds: String) => {
        (pid, level, cat1, cat2, cat3, kwds.split(" "))
      }
    }).toDF("pid", "level", "cat1", "cat2", "cat3", "words")

    /**
      * 4、数据处理；1）处理 cat1 ,cat2 ,cat3；2）处理 kwds
      */
    val idx1: StringIndexer = new StringIndexer()
      .setInputCol("cat1")
      .setOutputCol("c1")

    val idx2: StringIndexer = new StringIndexer()
      .setInputCol("cat2")
      .setOutputCol("c2")

    val idx3: StringIndexer = new StringIndexer()
      .setInputCol("cat3")
      .setOutputCol("c3")

    val idx1Df: DataFrame = idx1.fit(items).transform(items).drop("cat1")
    val idx2Df: DataFrame = idx2.fit(idx1Df).transform(idx1Df).drop("cat2")
    val idx3Df: DataFrame = idx3.fit(idx2Df).transform(idx2Df).drop("cat3")

    val hashingTF: HashingTF = new HashingTF().setInputCol("words").setNumFeatures(10000).setOutputCol("tf")
    val hashingDf: DataFrame = hashingTF.transform(idx3Df).drop("words")
    val idfDf: DataFrame = new IDF().setInputCol("tf").setOutputCol("idf").fit(hashingDf).transform(hashingDf).drop("tf")
    idfDf.show(10, false)

    // 5、将上面处理好的每个人的特征，整合成一个向量，注册为UDF函数 combineVec
    val combineVec: UserDefinedFunction = udf(
      (arr: mutable.WrappedArray[Double], vec: linalg.Vector) => {
        // 5-1、将向量先转回数组
        val vec1: Array[Double] = vec.toArray
        // 5-2、然后拼接两个数组
        val okVec: mutable.WrappedArray[Double] = arr.++(vec1)
        Vectors.dense(okVec.toArray).toSparse
      })

    // 6、整合所有特征为一个向量 ： （数组和向量）整合 as "features")
    val vecDf: DataFrame = idfDf.select('pid, combineVec(array('level, 'c1, 'c2, 'c3), 'idf).as("features"))
    vecDf.show(10, false)

    // 7、将物品和物品进行关联
    val joinedVec: DataFrame = vecDf.join(vecDf.toDF("pid2", "features2"), 'pid < 'pid2, "cross")

    // 8、求每两个物品之间的余弦相似度 ，UDF函数 cosSim
    val cosSim: UserDefinedFunction =
      udf(
        (v1: linalg.Vector, v2: linalg.Vector) => {
          val fenmu1: Double = v1.toArray.map(Math.pow(_, 2)).sum
          val fenmu2: Double = v2.toArray.map(Math.pow(_, 2)).sum

          val fenzi: Double = v1.toArray.zip(v2.toArray).map(tp => tp._1 * tp._2).sum
          fenzi / Math.pow(fenmu1 * fenmu2, 0.5)
        }
      )

    // 9、通过 余弦相似度 获取结果
    val itemSimilarity: DataFrame = joinedVec.select('pid, 'pid2, cosSim('features, 'features2) as "sim")
    itemSimilarity.show(50, false)

    FileUtils.deleteDir(new File("rec_system/outputdata/cb_out/item_item"))
    itemSimilarity.coalesce(1).write.parquet("rec_system/outputdata/cb_out/item_item")

    // 10、关闭Spark
    spark.close()
  }
}
